""" Multi-objective optimization: individuals.
"""

import numpy as np
import scipy as sp
import deap, deap.base, deap.tools
from etsynseg import pcdutil, bspline


class MOOFitness(deap.base.Fitness):
    """ DEAP fitness for individuals.

    Holds two objectives, each with weight -1.
    In DEAP a negative weight marks an objective to be minimized.
    """
    # objective order: (coverage penalty, non-membrane penalty)
    weights = (-1, -1)


class MOOIndiv(list):
    """ DEAP individuals.

    Each instance contains indexes of sampling points on the grids,
    [[i_u0v0,i_u0v1,...],[i_u1v0,i_u1v1,...],...].

    Attributes:
        fitness (MOOFitness): Fitness object, a new one is created for each instance.
    """
    def __init__(self, iterable=()):
        """ Fill the list from iterable and attach a new MOOFitness.

        Args:
            iterable: Initial contents, forwarded to the list constructor. Defaults to empty.
        """
        super().__init__(iterable)
        self.fitness = MOOFitness()


class MOOTools:
    """ Tools for individuals.

    Examples:
        # init
        mtools = MOOTools().init_from_grid(grid, fitness_rthresh)
        config = mtools.get_config()
        mtools_new = MOOTools().init_from_config(config)
        # generate individuals
        indiv = mtools.gen_random()
        # evolution
        mtools.mutate(indiv)
        indiv.fitness.values = mtools.evaluate(indiv)
        # save/load
        sample_list, fitness = mtools.indiv_to_simple(indiv)
        indiv = mtools.indiv_from_simple(sample_list, fitness)
        # evaluation
        zyx_fit, surf_fit = mtools.fit_surface(indiv, u_evals, v_evals)
        fitness = mtools.evaluate(indiv, u_evals, v_evals)
    
    Attributes:
        zyx (np.ndarray): 3d points to be fit. Shape=(npts,3), and each point is [zi,yi,xi].
        n_uz, n_vxy (int): The number of grids in u(z),v(xy) directions.
        uv_size: uv_size[(iu,iv)] is the number of points in the grid.
        uv_zyx: uv_zyx[(iu,iv)] is the array of point coordinates (in [z,y,x]) in the grid.
        fitness_rthresh (float): Distance threshold for fitness calculation. Beyond which the loss becomes constant.

    Methods:
        # init
        init_from_grid, init_from_config
        # io
        get_config, indiv_to_simple, indiv_from_simple
        # get coordinates
        flatten_net, get_coords_net, get_coords_flat
        # indiv generation
        gen_random, gen_uniform, gen_middle
        # operations
        mutate
        # evaluate
        fit_surface, calc_fitness, evaluate
    """
    #=========================
    # init, save/load
    #=========================
    
    def __init__(self):
        """ Create an empty tools object; no attributes are set here.

        The actual setup is done afterwards by self.init_from_grid or self.init_from_config.
        """
        return None
    
    def init_from_grid(self, grid, fitness_rthresh):
        """ Set up the attributes from a grid config plus a distance threshold.

        Args:
            grid (dict): Config of moosac.Grid. Keys={zyx,count_zyx,n_uz,n_vxy,uv_zyx,uv_size}.
            fitness_rthresh (float): Distance threshold for fitness calculation.

        Returns:
            self (MOOTools): Self object whose attributes are set.
        """
        # copy each grid entry onto the attribute of the same name
        for key in ("zyx", "count_zyx", "n_uz", "n_vxy", "uv_size", "uv_zyx"):
            setattr(self, key, grid[key])
        self.fitness_rthresh = fitness_rthresh
        # build the derived objects shared by both init paths
        self.init_common()
        return self

    def init_from_config(self, config):
        """ Restore the attributes from a config dict.

        Args:
            config (dict): Config generated by self.get_config.
                Keys={zyx,count_zyx,n_uz,n_vxy,uv_size,uv_zyx,fitness_rthresh}.

        Returns:
            self (MOOTools): Self object whose attributes are set.
        """
        attr_names = (
            "zyx", "count_zyx", "n_uz", "n_vxy",
            "uv_size", "uv_zyx", "fitness_rthresh",
        )
        # every config entry maps onto the attribute of the same name
        for name in attr_names:
            setattr(self, name, config[name])
        # build the derived objects shared by both init paths
        self.init_common()
        return self

    def init_common(self):
        """ Build the derived objects needed on top of the plain attributes.

        Sets self.surf_meta (degree-2 bspline.Surface), replaces self.zyx by its
        deduplicated version, and builds self.kdtree on the deduplicated points.
        """
        # bspline
        self.surf_meta = bspline.Surface(degree=2)
        # pointcloud
        # NOTE(review): self.count_zyx is not touched here; presumably zyx is already
        # unique and points_deduplicate keeps its length and order - confirm.
        unique_pts = pcdutil.points_deduplicate(self.zyx)
        self.zyx = unique_pts
        self.kdtree = sp.spatial.KDTree(unique_pts)

    def get_config(self):
        """ Collect the attributes into a dict.

        Convenient for dumping and loading; the result can be fed to self.init_from_config.

        Returns:
            config (dict): Dict of attributes.
                {zyx,count_zyx,n_uz,n_vxy,uv_size,uv_zyx,fitness_rthresh}.
        """
        attr_names = (
            "zyx", "count_zyx", "n_uz", "n_vxy",
            "uv_size", "uv_zyx", "fitness_rthresh",
        )
        return {name: getattr(self, name) for name in attr_names}

    def indiv_to_simple(self, indiv):
        """ Split an individual into plain objects for easier dumping and loading.

        Args:
            indiv (MOOIndiv): Individual.

        Returns:
            sample_list (list): List of sampling point indexes in each grid.
                A new outer list; the inner rows are the same objects as in indiv.
            fitness: The values of the individual's fitness (indiv.fitness.values).
        """
        # shallow copy of the outer container only
        rows = [row for row in indiv]
        return rows, indiv.fitness.values

    def indiv_from_simple(self, sample_list, fitness=None):
        """ Rebuild an individual from its sample list and, optionally, fitness values.

        Args:
            sample_list (list): List of sampling point indexes in each grid.
            fitness: Fitness values to assign to indiv.fitness.values.
                If None, the fitness is left unset.

        Returns:
            indiv (MOOIndiv): Individual.
        """
        indiv = MOOIndiv(sample_list)
        # nothing to restore when no fitness was given
        if fitness is None:
            return indiv
        indiv.fitness.values = fitness
        return indiv

    #=========================
    # conversion: indiv, points
    #=========================

    def flatten_net(self, net):
        """ Flatten net-shaped sampling points.
        
        Args:
            net (np.ndarray): Points with shape=(n_uz, n_vxy, 3).
        Returns: flat
            flat: shape=(n_uz*n_vxy, 3)
        """
        return net.reshape((-1, 3))

    def get_coords_net(self, indiv):
        """ Get sample coordinates in a net-shaped form from individual.

        Args:
            indiv (MOOIndiv): Individual.

        Returns:
            zyx_net (np.ndarray): Sample points arranged in a net shape, with shape=(n_uz,n_vxy,3).
        """
        zyx_net = np.zeros((self.n_uz, self.n_vxy, 3))
        for iu in range(self.n_uz):
            for iv in range(self.n_vxy):
                lb_i = indiv[iu][iv]
                zyx_net[iu][iv] = self.uv_zyx[(iu, iv)][lb_i]
        return zyx_net

    def get_coords_flat(self, indiv):
        """ Get sample coordinates in a flattened form from individual.

        Args:
            indiv (MOOIndiv): Individual.

        Returns:
            zyx_flat (np.ndarray): Sample points with shape=(n_uz*n_vxy,3).
        """
        zyx_net = self.get_coords_net(indiv)
        zyx_flat = self.flatten_net(zyx_net)
        return zyx_flat

    #=========================
    # indiv generation
    #=========================
    
    def gen_random(self):
        """ Generate individual by drawing a random point index in every grid.

        Indexes are drawn with np.random.randint in the range [0, uv_size[(iu,iv)]),
        visiting the grids row by row (u outer, v inner).

        Returns:
            indiv (MOOIndiv): Individual.
        """
        rows = (
            [np.random.randint(self.uv_size[(iu, iv)]) for iv in range(self.n_vxy)]
            for iu in range(self.n_uz)
        )
        return MOOIndiv(rows)

    def gen_uniform(self, index=0):
        """ Generate individual with uniform index in each grid.

        Args:
            index (int): The index to select. Will clip according to grid size.

        Returns:
            indiv (MOOIndiv): Individual. All entries are Python ints.
        """
        indiv = MOOIndiv()
        for iu in range(self.n_uz):
            # np.clip returns a NumPy scalar for scalar input; convert to a
            # Python int so entries have the same type as those produced by
            # gen_random/mutate and stay serializable in the "simple" form.
            indiv_u = [
                int(np.clip(index, 0, self.uv_size[(iu, iv)]-1))
                for iv in range(self.n_vxy)
            ]
            indiv.append(indiv_u)
        return indiv
    
    def gen_middle(self, pin_side=False):
        """ Generate individual with the middle point in each grid.

        Args:
            pin_side (bool): If True, points in side grids are pinned to the boundaries:
                index 0 for the first v-column, the last index for the last v-column.

        Returns:
            indiv (MOOIndiv): Individual. All entries are Python ints.
        """
        indiv = MOOIndiv()
        iv_last = self.n_vxy-1
        for iu in range(self.n_uz):
            indiv_u = []
            for iv in range(self.n_vxy):
                size_uv = self.uv_size[(iu, iv)]
                # set index to middle
                index = int((size_uv-1)/2)
                # for side grids, optionally set index to the boundary
                if pin_side:
                    if iv == 0:
                        index = 0
                    elif iv == iv_last:
                        index = size_uv-1
                # np.clip returns a NumPy scalar; convert to a Python int so
                # entries match the type produced by gen_random/mutate and
                # stay serializable in the "simple" form.
                indiv_u.append(int(np.clip(index, 0, size_uv-1)))
            indiv.append(indiv_u)
        return indiv

    #=========================
    # mutation, evaluation
    #=========================
    
    def mutate(self, indiv):
        """ Mutate individual in-place. Randomly resample one of the grids.

        Args:
            indiv (MOOIndiv): Individual.
        Returns: None
        """
        # select one grid to mutate
        iu = np.random.randint(0, self.n_uz)
        iv = np.random.randint(0, self.n_vxy)
        # randomly select one sample from the grid
        indiv[iu][iv] = np.random.randint(self.uv_size[(iu, iv)])
  
    def fit_surface(self, indiv, u_eval=None, v_eval=None, factor_eval=1):
        """ Fit surface from individual, evaluate at net.

        Args:
            indiv (MOOIndiv): Individual.
            u_eval, v_eval (np.ndarray): 1d arrays of u(z) and v(xy) to evaluate at, which range from [0,1].
                Defaults to sample the max length of wireframes.
            factor_eval (float): The default number of evaluation points = wireframe length * factor_eval.

        Returns:
            zyx_fit (np.ndarray): Evaluated points on the fitted surface, with shape=(npts,3).
            surf_fit (splipy.surface.Surface): Fitted surface.
        """
        # fit
        pts_net = self.get_coords_net(indiv)
        surf_fit = self.surf_meta.interpolate(pts_net)

        # set eval points
        # default: max wireframe length
        if u_eval is None:
            nu_eval = int(np.max(pcdutil.wireframe_length(pts_net, axis=0)))
            u_eval = np.linspace(0, 1, int(factor_eval*nu_eval))
        if v_eval is None:
            nv_eval = int(np.max(pcdutil.wireframe_length(pts_net, axis=1)))
            v_eval = np.linspace(0, 1, int(factor_eval*nv_eval))

        # convert fitted surface to points
        zyx_fit = self.flatten_net(surf_fit(u_eval, v_eval))
        return zyx_fit, surf_fit

    def calc_fitness(self, zyx_fit):
        """ Calculate the two penalty terms of the fitness.

        Args:
            zyx_fit (np.ndarray): Evaluated points on the fitted surface, with shape=(npts,3).

        Returns:
            fitness (tuple): (fitness_coverage, fitness_excess).
                fitness_coverage: sum over self.zyx of count_zyx * d^2, where d is the
                    nearest-neighbour distance to the fit, clipped to [0, fitness_rthresh].
                fitness_excess: number of (deduplicated) fit points whose nearest-neighbour
                    distance to self.zyx exceeds fitness_rthresh.
        """
        # deduplicate, convert to pointcloud
        pts_fit = pcdutil.points_deduplicate(zyx_fit)
        tree_fit = sp.spatial.KDTree(pts_fit)

        # coverage of zyx by fit: dist from zyx to fit
        d_to_fit, _ = tree_fit.query(self.zyx, k=1, workers=-1)
        capped_sq = np.clip(d_to_fit, 0, self.fitness_rthresh)**2
        penalty_coverage = np.sum(self.count_zyx*capped_sq)

        # excessive pixels of fit compared with zyx: dist from fit to zyx
        d_to_zyx, _ = self.kdtree.query(pts_fit, k=1, workers=-1)
        penalty_excess = np.sum(d_to_zyx > self.fitness_rthresh)

        # moo fitness
        return (penalty_coverage, penalty_excess)

    def evaluate(self, indiv, u_eval=None, v_eval=None):
        """ Evaluate fitness of individual: fit a surface, then score the evaluated points.

        Args:
            indiv (MOOIndiv): Individual.
            u_eval, v_eval (np.ndarray): 1d arrays of u(z) and v(xy) to evaluate at, which range from [0,1].
                Passed on to self.fit_surface, which picks defaults when they are None.

        Returns:
            fitness: Result of self.calc_fitness on the fitted points.
        """
        # the fitted surface object itself is not needed for scoring
        fitted_pts, _surf = self.fit_surface(indiv, u_eval=u_eval, v_eval=v_eval)
        return self.calc_fitness(fitted_pts)
